{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "EnergyDisaggregationDemo_View.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 2",
      "language": "python",
      "name": "python2"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "3BykKKaz_eE6",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# License\n",
        "\n",
        "Copyright 2019 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
        "you may not use this file except in compliance with the License.  \n",
        "You may obtain a copy of the License at . \n",
        "\n",
        "      http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing,  \n",
        "software distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
        "See the License for the specific language governing permissions and  \n",
        "limitations under the License."
      ]
    },
    {
      "metadata": {
        "id": "u35VyF5SwDVz",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Pre-work"
      ]
    },
    {
      "metadata": {
        "id": "x_zuXxzQjBb0",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Upload files (skip this if this is run locally)\n",
        "\n",
        "# Use this cell to update the following files\n",
        "#   1. requirements.txt\n",
        "#   2. e2e_demo_credential.json\n",
        "from google.colab import files\n",
        "uploaded = files.upload()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XkBx4CguiA0Z",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Install missing packages\n",
        "\n",
        "# run this cell to install packages if some are missing\n",
        "!pip install -r ./requirements.txt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XdolU7HIiA0c",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Import libraries\n",
        "\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "import json\n",
        "import numpy as np\n",
        "import os\n",
        "import pandas as pd\n",
        "import pandas_gbq\n",
        "import seaborn as sns\n",
        "import time\n",
        "import Queue as queue\n",
        "from google.cloud import pubsub_v1\n",
        "from IPython import display"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "11f98L1viA0g",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Configurations\n",
        "\n",
        "# project related\n",
        "GOOGLE_CLOUD_PROJECT = 'your-google-project-id' #@param\n",
        "GOOGLE_APPLICATION_CREDENTIALS = 'e2e_demo_credential.json' #@param\n",
        "os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = GOOGLE_APPLICATION_CREDENTIALS\n",
        "\n",
        "# data related\n",
        "DATASET_ID = 'EnergyDisaggregation'\n",
        "\n",
        "# pubsub related\n",
        "PRED_TOPIC = 'pred'\n",
        "SUB_NAME = 'sub1'\n",
        "DEVICE_ID = 'target-device-to-monitor' #@param"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "aDLEKuTtiA0j",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# BQ data loading"
      ]
    },
    {
      "metadata": {
        "id": "IgR9WcYViA0j",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Utility\n",
        "\n",
        "def get_appliance_info(project_id, dataset):\n",
        "    \"\"\"Get appliance info from the big query table.\n",
        "    \n",
        "    Load appliance info from big query table.\n",
        "    \n",
        "    Args:\n",
        "      project_id: str, google cloud project id.\n",
        "      dataset: str, name of the dataset.\n",
        "    Returns:\n",
        "      pandas.DataFrame, appliance info.\n",
        "    \"\"\"\n",
        "    res = pandas_gbq.read_gbq(\n",
        "        'SELECT * FROM {}.ApplianceInfo'.format(dataset),\n",
        "        project_id,\n",
        "        dialect='legacy')\n",
        "    return res\n",
        "\n",
        "def load_ground_truth(project_id, dataset, appliance_id):\n",
        "    \"\"\"Load ground-truth data of a specified appliance from big query table.\n",
        "\n",
        "    Load true appliance status from big query data.\n",
        "\n",
        "    Args:\n",
        "        project_id: str, google project id.\n",
        "        dataset: str, dataset name.\n",
        "        appliance_id: int, appliance id.\n",
        "    Returns:\n",
        "        pandas.DataFrame, ground truth appliance status.\n",
        "    \"\"\"\n",
        "    query = \"\"\"\n",
        "    SELECT * FROM {}.ApplianceStatusGroundTruth\n",
        "    WHERE appliance_id = {}\n",
        "    ORDER BY time\n",
        "    \"\"\".format(dataset, appliance_id)\n",
        "    return pandas_gbq.read_gbq(query,\n",
        "                               project_id,\n",
        "                               index_col='time',\n",
        "                               dialect='legacy')\n",
        "\n",
        "def load_data(project_id, dataset, app_ids):\n",
        "    \"\"\"Load ground-truth data from big query table.\n",
        "\n",
        "    Load true appliance status from big query data.\n",
        "\n",
        "    Args:\n",
        "        project_id: str, google project id.\n",
        "        dataset: str, dataset name.\n",
        "        app_ids: list, appliances' ids.\n",
        "    Returns:\n",
        "        dict, {appliance_id: ground truth data (pandas.DataFrame)}\n",
        "    \"\"\"\n",
        "    print('Loading ground truth data ...')\n",
        "    gt = {}\n",
        "    for k in app_ids:\n",
        "        gt[k] = load_ground_truth(project_id, dataset, k)\n",
        "    print('Data loaded.')\n",
        "    return gt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "sLijm-ldiA0m",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Load appliance info\n",
        "\n",
        "app_info = get_appliance_info(GOOGLE_CLOUD_PROJECT, DATASET_ID)\n",
        "app_info"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "MSguFEpsiA0o",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Create appliance Id to name mapping\n",
        "\n",
        "app_id_name_map = {row[1]: row[0] for i, row in app_info.iterrows()}\n",
        "app_id_name_map"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "dmzRNLX8iA0r",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Load ground truth data (this can take some time)\n",
        "ground_truth = load_data(project_id=GOOGLE_CLOUD_PROJECT,\n",
        "                         dataset=DATASET_ID,\n",
        "                         app_ids=app_id_name_map.keys())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "kZor4sdbiA0t",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Start subscriber"
      ]
    },
    {
      "metadata": {
        "id": "Fc6e4XHHiA0z",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "When the following cell is run, it will pull messages from a pub/sub topic.  \n",
        "The message contains both raw power readings and prediction results from CMLE.  \n",
        "It plots the raw active powers, it also updates precision/recall for each appliance.  "
      ]
    },
    {
      "metadata": {
        "id": "mA2aValZiA0v",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Utility\n",
        "\n",
        "class MsgProcesser(object):\n",
        "    \"\"\"Subscribe to a pub/sub topic and process incoming messages.\"\"\"\n",
        "    \n",
        "    def __init__(self,\n",
        "                 project_id,\n",
        "                 ground_truth,\n",
        "                 topic_name,\n",
        "                 subscription_name,\n",
        "                 app_id_name_map,\n",
        "                 target_device):\n",
        "        # member initialization\n",
        "        self._gt = ground_truth\n",
        "        self._data = pd.Series(np.zeros(self._gt[0].shape[0]),\n",
        "                                index=self._gt[0].index)\n",
        "        self._app_id_to_name = app_id_name_map\n",
        "        self._target_device = target_device\n",
        "        self._app_names = [app_id_name_map[k]\n",
        "                           for k in range(len(app_id_name_map.keys()))]\n",
        "        self._queue = queue.Queue(maxsize=100000)\n",
        "        # create subsciprtion\n",
        "        self._subscriber, self._subscription_path = (\n",
        "            self.create_subscription(project_id, topic_name, subscription_name))\n",
        "        self._subscriber.subscribe(self._subscription_path,\n",
        "                                   callback=self._msg_callback)\n",
        "        \n",
        "    def create_subscription(self, project_id, topic_name, subscription_name):\n",
        "        \"\"\"Create a subscription in pub/sub.\n",
        "    \n",
        "        Before listening to incoming messages, we need to define a subscription.\n",
        "        This function creates a new subscription only if it does not exist.\n",
        "\n",
        "        Args:\n",
        "          project_id: str, google cloud project id.\n",
        "          topic_name: str, topic name.\n",
        "          subscription_name: str, name of the new subscription.\n",
        "        Returns:\n",
        "          (google.cloud.pubsub_v1.SubscriberClient, topic_path)\n",
        "        \"\"\"\n",
        "        print('Creating subscription \"{}\" to topic \"{}\" ...'.format(\n",
        "            subscription_name, topic_name))\n",
        "        try:\n",
        "            subscriber = pubsub_v1.SubscriberClient()\n",
        "            topic_path = subscriber.topic_path(project_id, topic_name)\n",
        "            subscription_path = subscriber.subscription_path(\n",
        "                project_id, subscription_name)\n",
        "            # if the subscription exists, exception is raised here\n",
        "            subscription = subscriber.create_subscription(\n",
        "                subscription_path, topic_path)\n",
        "            print('Subscription created: {}'.format(subscription))\n",
        "        except Exception as e:\n",
        "            print('Subscription \"{}\" existed.'.format(subscription_name))\n",
        "        return subscriber, subscription_path\n",
        "    \n",
        "    def async_pull_msg(self):\n",
        "        \"\"\"Pull messages asynchronously from pub/sub topic.\"\"\"\n",
        "        # initialize metrics\n",
        "        self._metrics = [{'TP': 0, 'TN': 0, 'FP': 0, 'FN': 0}\n",
        "                         for k in range(len(self._app_id_to_name))]\n",
        "        precisions = [0. for k in self._app_id_to_name]\n",
        "        recalls = [0. for k in self._app_id_to_name]\n",
        "        # initialize UI\n",
        "        fig, ax = plt.subplots(1, 1, figsize=(16, 4))\n",
        "        sns.axes_style('white')\n",
        "        # receive data and update UI\n",
        "        print('Listening for messages on {} ...'.format(\n",
        "          self._subscription_path))\n",
        "        max_t = None\n",
        "        while(True):\n",
        "            # collect data from callback threads\n",
        "            t, d, probs = self._queue.get()\n",
        "            if max_t is None:\n",
        "                max_t = t\n",
        "            else:\n",
        "                max_t = max(t, max_t)\n",
        "            self._data.at[t] = d\n",
        "            for k in self._app_id_to_name:\n",
        "                truth = self._gt[k].status.at[t]\n",
        "                pred = 1 if probs[k] >= 0.5 else 0\n",
        "                if truth == 1 and pred == 1:\n",
        "                    self._metrics[k]['TP'] += 1\n",
        "                elif truth == 0 and pred == 0:\n",
        "                    self._metrics[k]['TN'] += 1\n",
        "                elif truth == 1 and pred == 0:\n",
        "                    self._metrics[k]['FN'] += 1\n",
        "                else:\n",
        "                    self._metrics[k]['FP'] += 1\n",
        "                TP = self._metrics[k]['TP']\n",
        "                FP = self._metrics[k]['FP']\n",
        "                FN = self._metrics[k]['FN']\n",
        "                precisions[k] = (\n",
        "                  np.nan if TP + FP == 0 else round(1.0 * TP / (TP + FP), 2))\n",
        "                recalls[k] = (\n",
        "                  np.nan if TP + FN == 0 else round(1.0 * TP / (TP + FN), 2))\n",
        "                \n",
        "            score_matrix = pd.DataFrame({'Precision': precisions,\n",
        "                                         'Recall': recalls},\n",
        "                                        index=self._app_names)\n",
        "                        \n",
        "            # update view in the main thread\n",
        "            ax.clear()\n",
        "            x = max_t.astype('datetime64[h]')\n",
        "            mask = (self._data.index >= x) & (self._data.index < x + 1)\n",
        "            ax = sns.lineplot(data=self._data[mask], linewidth=2.5, ax=ax)\n",
        "            ax.set_xlabel('time')\n",
        "            ax.set_ylabel('active power')\n",
        "            ax.set_ylim(bottom=0)\n",
        "            title = 'Device: {}, Date: {} (UTC)'.format(\n",
        "              self._target_device, max_t.astype('datetime64[D]'))\n",
        "            ax.set_title(title)\n",
        "            display.clear_output(wait=True)\n",
        "            display.display(plt.gcf())\n",
        "            display.display(score_matrix)\n",
        "            \n",
        "    def _msg_callback(self, message):\n",
        "        \"\"\"Pub/sub pull callback.\"\"\"\n",
        "        try:\n",
        "            data = json.loads(message.data.decode('utf-8'))\n",
        "            device_id = data['device_id']\n",
        "            if device_id == self._target_device:\n",
        "              t = np.datetime64(data['time'][-1])\n",
        "              d = data['data'][-1]\n",
        "              probs = data['probs']\n",
        "              self._queue.put((t, d, probs))\n",
        "        except Exception as e:\n",
        "            print('Error: {}'.format(e))\n",
        "        message.ack()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "DSUQgFu5iA00",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# @title Visualization\n",
        "\n",
        "tt = MsgProcesser(project_id=GOOGLE_CLOUD_PROJECT,\n",
        "                  ground_truth=ground_truth,\n",
        "                  topic_name=PRED_TOPIC,\n",
        "                  subscription_name=SUB_NAME,\n",
        "                  app_id_name_map=app_id_name_map,\n",
        "                  target_device=DEVICE_ID)\n",
        "tt.async_pull_msg()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "85Ty2zNBiA02",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}